{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\\ell(x, y) = L = \\{l_1,\\dots,l_N\\}^\\top, \\quad\n",
    "        l_n = \\left( x_n - y_n \\right)^2$$\n",
    "$$\\ell(x, y) =\n",
    "        \\begin{cases}\n",
    "            \\operatorname{mean}(L), &  \\text{if reduction} = \\text{`mean';}\\\\\n",
    "            \\operatorname{sum}(L),  &  \\text{if reduction} = \\text{`sum'.}\n",
    "        \\end{cases}$$\n",
    ">tips: 对于 torch.nn.MSELoss 函数，size_average 参数已弃用，可改用 reduction 参数按上式选择是否取均值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled workaround, kept for reference. NOTE(review): presumably meant for setups where\n",
    "# importing torch/matplotlib aborts with the OpenMP duplicate-runtime error (OMP: Error #15);\n",
    "# uncomment both lines only if that error appears - confirm it is still needed.\n",
    "# import os\n",
    "# os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "w =  2.000006675720215\n",
      "b =  -1.49944453369244e-05\n",
      "y_pred =  8.000011444091797\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYEAAAEGCAYAAACD7ClEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAATiElEQVR4nO3dbZBkV33f8e+vZ3f1LEsbjVSKBF6IVcSKYwReMA8J5Vg2MdhlqWxkiCO8RZToRUgMsSuOFOxQyYsUeXKZFw7xljC1FDIYC2GpMGVQ1kBCqiJ59YCRtCgST0KwaMeOAQmCpNX+86LvLOPZ21ej1bR6+vT3UzV1u+/c7nv+tTvzm3Nun3NTVUiSFtNo1g2QJM2OISBJC8wQkKQFZghI0gIzBCRpgW2bdQM24pxzzqldu3bNuhmSNFduv/32P6+q5aFj5iIEdu3axYEDB2bdDEmaK0m+/FTHOBwkSQvMEJCkBWYISNICMwQkaYEZApK0wAwBSVpghoAkLbCmQ2D/wYf5r598YNbNkKQtq+kQ+MR9h7nuf35x1s2QpC2r6RAYJXjTHEmarOkQCHDUDJCkiaYaAkn+RZJ7ktyd5P1JTk6yM8ktSe7vtmdP8fz2BCRpwNRCIMkFwC8Du6vqh4Al4A3ANcD+qroI2N89nxojQJImm/Zw0DbglCTbgFOBrwGXAfu67+8DLp/WyRNMAUkaMLUQqKqvAv8ZeBA4BHyzqj4OnFdVh7pjDgHnTqsNo8QMkKQB0xwOOpvxX/3PA/46cFqSK5/G669OciDJgZWVlRNrA3DUawKSNNE0h4N+AvhiVa1U1RPAjcArgIeTnA/QbQ/3vbiq9lbV7qravbw8eGOciRIwAyRpsmmGwIPAy5KcmiTApcBB4GZgT3fMHuCmaTUgCeWAkCRNNLXbS1bVrUluAO4AjgB3AnuB04EPJrmKcVBcMa02BHsCkjRkqvcYrqq3A29ft/sxxr2CqYsXhiVpUNszhoOTxSRpQNshgMNBkjSk7RCIc8UkaUjTIeAqopI0rOkQcBVRSRrWdAiMFw+SJE3SdAisRoBDQpLUr+0Q6FLADJCkfk2HwKhLATNAkvo1HQKrw0GuJCpJ/doOAYeDJGlQ4yGwOhxkCkhSn6ZDYJU9AUnq13QIjJwnIEmDmg6B1QzwwrAk9Ws7BLqtGSBJ/doOgdVPB822GZK0ZTUdAscmi9kVkKReTYfAKlcSlaR+TYdAHA+SpEFth0C3dbKYJPVrOwRcNkKSBjUdAq4iKknDmg4BJ4tJ0rC2Q6DbmgGS1K/pEMBVRCVpUNMhcGz5ODNAkno1HQJeGJakYU2HgBeGJWlY2yHQbc0ASerXdgi4aoQkDWo8BFxFVJKGtB0C3dYMkKR+bYfAsZ7AjBsiSVtU2yHQbZ0sJkn92g4BVxGVpEFNh4CTxSRpWNMh4GQxSRrWdAisMgMkqd9UQyDJWUluSPK5JAeTvDzJziS3JLm/2549xfN3j0wBSeoz7Z7AO4E/rqq/CbwQOAhcA+yvqouA/d3zqXCegCQNm1oIJDkTeBXwboCqeryqvgFcBuzrDtsHXD6tNnhhWJKGTbMn8HxgBXhPkjuTXJfkNOC8qjoE0G3P7XtxkquTHEhyYGVl5YQa4IVhSRo2zRDYBrwYeFdVvQj4Nk9j6Keq9lbV7qravby8fEINcDhIkoZNMwQeAh6qqlu75zcwDoWHk5wP0G0PT6sBThaTpGFTC4Gq+jrwlSQv6HZdCtwL3Azs6fbtAW6aVhtW+wIuGyFJ/bZN+f3/OXB9kh3AF4A3MQ6eDya5CngQuGJaJx/ZE5CkQVMNgaq6C9jd861Lp3neVa4iKknDmp4x7CqikjSs7RBwOEiSBjUdAk4Wk6RhTYcAThaTpEFNh4CTxSRpWNsh4CqikjSo7RDotvYEJKlf0yHghWFJGtZ0CBxbRfSoMSBJfdoOgW5rBEhSv6ZDACeLSdKgpkMgriIqSYOa
DoGR40GSNKjpEFidJ+B1YUnq13gIjLcOB0lSv7ZDoNt6YViS+rUdAk4Wk6RBjYfAeOsqopLUr+0QWH1gBkhSr7ZDIM4TkKQhbYdAt3U0SJL6NR0Cx1YRNQQkqVfTIeCFYUka1nQIrDICJKlf0yEQVxGVpEFth4B3FJCkQU2HwKirzp6AJPVrOgRWewKuIipJ/doOAVcRlaRBbYdAt3U4SJL6tR0CriIqSYMaD4HxtuwKSFKvtkOg25oBktRvQyGQ5C1JzszYu5PckeTV027cM+UqopI0bKM9gX9UVd8CXg0sA28C3jG1Vm0SewKSNGyjIbD6+/S1wHuq6jNr9m1ZriIqScM2GgK3J/k44xD4WJIzgKPTa9bmcBVRSRq2bYPHXQVcAnyhqr6TZCfjIaG5YARIUr+N9gReDtxXVd9IciXw68A3p9eszRHXj5OkQRsNgXcB30nyQuDXgC8D793IC5MsJbkzyUe65zuT3JLk/m579gm1fGPnBvx0kCRNstEQOFLjGVeXAe+sqncCZ2zwtW8BDq55fg2wv6ouAvZ3z6di5P0EJGnQRkPgkSTXAm8E/ijJErD9qV6U5ELgp4Hr1uy+DNjXPd4HXL7h1j5NriIqScM2GgKvBx5jPF/g68AFwH/awOt+i/Hw0dpPEp1XVYcAuu25fS9McnWSA0kOrKysbLCZ699jvHU4SJL6bSgEul/81wPfl+RngO9W1eA1ge64w1V1+4k0rKr2VtXuqtq9vLx8Im/hZDFJegobXTbiF4DbgCuAXwBuTfK6p3jZK4GfTfIl4APAjyd5H/BwkvO79z0fOHyCbd9IuwE/HCRJk2x0OOhtwEuqak9V/RLwUuA3hl5QVddW1YVVtQt4A/AnVXUlcDOwpztsD3DTCbV8A1xFVJKGbTQERlW19i/2v3gar13vHcBPJrkf+EmmuAaRw0GSNGyjM4b/OMnHgPd3z18PfHSjJ6mqTwKf7B7/BXDpxpt44o4NB5kCktRrQyFQVf8yyc8zHucPsLeqPjzVlm0CJwxL0rCN9gSoqg8BH5piWzbd6iqizhOQpH6DIZDkEfr/kA5QVXXmVFq1SdJdtXA4SJL6DYZAVW10aYgt6Xs9AUNAkvo0fY/hpS4Entzydz6QpNloOgRGXXX2BCSpX9MhsNoTOOqVYUnq1XQIrF4TeNKegCT1ajsERvYEJGlI0yEAsDSKPQFJmqD9EEicLCZJEzQfAonDQZI0SfMhsDQKTxoCktSr/RCI1wQkaZLmQ2A0ivcTkKQJ2g+B4HCQJE3QfAj4EVFJmqz5EBglfjpIkiZoPgSWRnEBOUmaoPkQGCUuJS1JE7QfAiOXkpakSZoPgaU4WUySJmk+BEZeE5CkidoPgRgCkjRJ8yHgcJAkTdZ8CIxGfjpIkiZpPgSWRlAOB0lSr+ZDYOQqopI00WKEgNcEJKlX8yHgshGSNFn7IZBw1AvDktSr+RBI8JqAJE3QfAgsjVxKWpImWYgQsCcgSf2aD4HxshGzboUkbU0LEAI4HCRJEzQfAksj5wlI0iTNh4CriErSZIaAJC2wqYVAkuck+USSg0nuSfKWbv/OJLckub/bnj2tNoDDQZI0ZJo9gSPAr1bVDwIvA96c5GLgGmB/VV0E7O+eT834zmLTPIMkza+phUBVHaqqO7rHjwAHgQuAy4B93WH7gMun1QaApWBPQJImeFauCSTZBbwIuBU4r6oOwTgogHMnvObqJAeSHFhZWTnhc3tNQJImm3oIJDkd+BDw1qr61kZfV1V7q2p3Ve1eXl4+4fOPXDZCkiaaaggk2c44AK6vqhu73Q8nOb/7/vnA4Wm2YcmbykjSRNP8dFCAdwMHq+o313zrZmBP93gPcNO02gDeY1iShmyb4nu/Engj8Nkkd3X7/jXwDuCDSa4CHgSumGIbGMV7DEvSJFMLgar6NJAJ3750Wuddz1VEJWmyhZgx7EdEJalf8yGwNAp2BCSpX/MhMAoc8SbDktSr+RDYvjRy
OEiSJliIEHjiyfITQpLUo/kQ2LFtXOLjThaQpOM0HwLbl8afUn3iSXsCkrTeAoTAuMQnjtgTkKT1mg+B1eGgJxwOkqTjNB8Cqz0BrwlI0vGaD4Edq8NBXhOQpOM0HwLHrgnYE5Ck4yxACIw/HfS4F4Yl6Tjth4DzBCRpouZDYIcfEZWkiZoPge1eGJakiRYgBFZnDNsTkKT1mg8B1w6SpMnaDwE/IipJEzUfAs4TkKTJ2g+B1bWDjnhhWJLWaz8EugvDj9kTkKTjNB8CzhOQpMmaDwGvCUjSZM2HwOpHRB+zJyBJx2k+BLYvjdixbcS3Hz8y66ZI0pbTfAgAnHHSNh79riEgSestRAicfvI2Hn3MEJCk9RYiBE7bYU9AkvosRAjYE5CkfgsRAmecZAhIUp+FCAF7ApLUbzFCwE8HSVKvhQmBR+wJSNJxFiIEzjp1B48fOeqQkCStsxAhcMHZpwDw1b/8fzNuiSRtLYsRAmd1IfCN78y4JZK0tSxECFxoT0CSei1ECCyffhInbx/x+ZVvz7opkrSlzCQEkvxUkvuSPJDkmmmfbzQKL9m1k//1wJ9P+1SSNFee9RBIsgT8NvAa4GLgHyS5eNrn/XsvOJf7Dz/KLfc+TJX3G5YkgG0zOOdLgQeq6gsAST4AXAbcO82T/uKPPpffu+1B/sl7D7DztB2cumOJHUsjRqNM87R6mvzXkP6qf/9zf5uX7No5tfefRQhcAHxlzfOHgB9df1CSq4GrAZ773Oc+45OevH2JG//pK/jDO7/KwUPf4rEjR3n8yFHsFGwdhf8Y0nqnbF+a6vvPIgT6/tg77qe/qvYCewF27969Kb8dzjx5O7/08l2b8VaS1IRZXBh+CHjOmucXAl+bQTskaeHNIgT+FLgoyfOS7ADeANw8g3ZI0sJ71oeDqupIkn8GfAxYAn63qu55ttshSZrNNQGq6qPAR2dxbknS9yzEjGFJUj9DQJIWmCEgSQvMEJCkBZZ5WEcnyQrw5RN8+TlASyvHWc/W11pN1rO1DdXz/VW1PPTiuQiBZyLJgaraPet2bBbr2fpaq8l6trZnWo/DQZK0wAwBSVpgixACe2fdgE1mPVtfazVZz9b2jOpp/pqAJGmyRegJSJImMAQkaYE1HQLP9g3tN0OS301yOMnda/btTHJLkvu77dlrvndtV999Sf7+bFo9WZLnJPlEkoNJ7knylm7/XNaU5OQktyX5TFfPv+32z2U9q5IsJbkzyUe653NbT5IvJflskruSHOj2zXM9ZyW5Icnnup+jl29qPVXV5BfjZao/Dzwf2AF8Brh41u3aQLtfBbwYuHvNvv8IXNM9vgb4D93ji7u6TgKe19W7NOsa1tVzPvDi7vEZwP/p2j2XNTG+M97p3ePtwK3Ay+a1njV1/Qrwe8BHGvg/9yXgnHX75rmefcA/7h7vAM7azHpa7gkcu6F9VT0OrN7Qfkurqv8B/N91uy9j/B+Bbnv5mv0fqKrHquqLwAOM694yqupQVd3RPX4EOMj4PtNzWVONPdo93d59FXNaD0CSC4GfBq5bs3tu65lgLutJcibjPwzfDVBVj1fVN9jEeloOgb4b2l8wo7Y8U+dV1SEY/1IFzu32z1WNSXYBL2L81/Pc1tQNndwFHAZuqaq5rgf4LeDXgKNr9s1zPQV8PMntSa7u9s1rPc8HVoD3dMN11yU5jU2sp+UQ2NAN7efc3NSY5HTgQ8Bbq+pbQ4f27NtSNVXVk1V1CeP7Y780yQ8NHL6l60nyM8Dhqrp9oy/p2bdl6um8sqpeDLwGeHOSVw0cu9Xr2cZ4ePhdVfUi4NuMh38medr1tBwCLd3Q/uEk5wN028Pd/rmoMcl2xgFwfVXd2O2e65oAum75J4GfYn7reSXws0m+xHjI9MeTvI/5rYeq+lq3PQx8mPFwyLzW8xDwUNfbBLiBcShsWj0th0BLN7S/GdjTPd4D3LRm/xuSnJTkecBFwG0zaN9EScJ4PPNgVf3mmm/N
ZU1JlpOc1T0+BfgJ4HPMaT1VdW1VXVhVuxj/jPxJVV3JnNaT5LQkZ6w+Bl4N3M2c1lNVXwe+kuQF3a5LgXvZzHpmfeV7ylfVX8v40yifB9426/ZssM3vBw4BTzBO9auAvwbsB+7vtjvXHP+2rr77gNfMuv099fwdxt3RPwPu6r5eO681AT8M3NnVczfwb7r9c1nPutp+jO99Omgu62E8hv6Z7uue1Z/7ea2na98lwIHu/9wfAmdvZj0uGyFJC6zl4SBJ0lMwBCRpgRkCkrTADAFJWmCGgCQtMENAmrIkP7a6Oqe01RgCkrTADAGpk+TK7l4BdyX5nW6huEeT/JckdyTZn2S5O/aSJP87yZ8l+fDqeu5JfiDJf+/uN3BHkr/Rvf3pa9aEv76bSS3NnCEgAUl+EHg948XHLgGeBP4hcBpwR40XJPsU8PbuJe8F/lVV/TDw2TX7rwd+u6peCLyC8exvGK+e+lbG670/n/GaPdLMbZt1A6Qt4lLgR4A/7f5IP4XxolxHgd/vjnkfcGOS7wPOqqpPdfv3AX/QrVlzQVV9GKCqvgvQvd9tVfVQ9/wuYBfw6alXJT0FQ0AaC7Cvqq79KzuT31h33NA6K0NDPI+tefwk/uxpi3A4SBrbD7wuyblw7J6038/4Z+R13TG/CHy6qr4J/GWSv9vtfyPwqRrfJ+GhJJd373FSklOfzSKkp8u/RiSgqu5N8uuM70g1YryK65sZ38TjbyW5Hfgm4+sGMF6+9791v+S/ALyp2/9G4HeS/LvuPa54FsuQnjZXEZUGJHm0qk6fdTukaXE4SJIWmD0BSVpg9gQkaYEZApK0wAwBSVpghoAkLTBDQJIW2P8HUelMY1EwvNMAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "# Prepare the dataset.\n",
    "# Both tensors are 3x1 matrices: three samples, each with a single feature.\n",
    "x_data = torch.tensor(\n",
    "    [\n",
    "        [1.0],\n",
    "        [2.0],\n",
    "        [3.0],\n",
    "    ]\n",
    ")\n",
    "y_data = torch.tensor(\n",
    "    [\n",
    "        [2.0],\n",
    "        [4.0],\n",
    "        [6.0],\n",
    "    ]\n",
    ")\n",
    " \n",
    "# Design the model as a class.\n",
    "class LinearModel(torch.nn.Module):\n",
    "    \"\"\"Linear regression model built from a single torch.nn.Linear layer.\n",
    "\n",
    "    A custom model inherits from torch.nn.Module, the base class of all neural\n",
    "    network modules, and implements __init__() and forward().\n",
    "    torch.nn.Linear holds two member tensors, weight and bias.\n",
    "    torch.nn.Module implements __call__(), so an instance can be called like a\n",
    "    function; model(x) normally ends up running forward(x).\n",
    "\n",
    "    Args:\n",
    "        in_features: number of features of each input sample (default 1).\n",
    "        out_features: number of features of each prediction (default 1).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features=1, out_features=1):\n",
    "        super().__init__()\n",
    "        # The defaults (1, 1) match the dataset in this notebook, where x and y are both 1-dimensional.\n",
    "        # The parameters learned by this layer are w and b, available as\n",
    "        # self.linear.weight and self.linear.bias.\n",
    "        self.linear = torch.nn.Linear(in_features, out_features)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the linear layer's prediction for x (last dimension must equal in_features).\"\"\"\n",
    "        return self.linear(x)\n",
    " \n",
    "model = LinearModel()\n",
    "\n",
    "# Construct the loss and the optimizer.\n",
    "# reduction='sum' adds up the squared errors instead of averaging them; it replaces the\n",
    "# deprecated form torch.nn.MSELoss(size_average=False).\n",
    "criterion = torch.nn.MSELoss(reduction='sum')\n",
    "# model.parameters() hands the optimizer the parameters it has to update.\n",
    "# Alternative optimizer to experiment with: torch.optim.Adagrad(model.parameters(), lr=0.01)\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
    "\n",
    "# Bookkeeping for the loss curve: epoch counter plus one (epoch, loss) entry per iteration.\n",
    "epoch, epoch_list, loss_list = 0, [], []\n",
    "# Training cycle: forward, backward, update.\n",
    "LOSS_TOLERANCE = 1e-10  # stop as soon as the loss drops below this value\n",
    "MAX_EPOCHS = 100000  # safety cap so the loop ends even if the loss never gets below the tolerance (e.g. plateau or NaN)\n",
    "while epoch < MAX_EPOCHS:\n",
    "    y_pred = model(x_data)  # forward: predict\n",
    "    loss = criterion(y_pred, y_data)  # forward: loss\n",
    "    loss.backward()  # backward: autograd computes the gradients\n",
    "    optimizer.step()  # update the parameters, i.e. the values of w and b\n",
    "    # Gradients computed by .backward() accumulate, so reset them before the next backward pass.\n",
    "    optimizer.zero_grad()\n",
    "    loss_value = loss.item()  # convert once; reused for the history and the stop test\n",
    "    epoch_list.append(epoch)\n",
    "    loss_list.append(loss_value)\n",
    "    epoch += 1\n",
    "    if loss_value < LOSS_TOLERANCE:\n",
    "        break\n",
    "else:\n",
    "    # Reached only when the loop ran out of epochs without hitting the tolerance.\n",
    "    print('warning: stopped after', MAX_EPOCHS, 'epochs, last loss =', loss_list[-1] if loss_list else None)\n",
    "\n",
    "# Report the learned parameters.\n",
    "print('w = ', model.linear.weight.item())\n",
    "print('b = ', model.linear.bias.item())\n",
    "\n",
    "# Predict for an input that is not part of the training data.\n",
    "x_test = torch.tensor([[4.0]])\n",
    "y_test = model(x_test)\n",
    "print('y_pred = ', y_test.item())\n",
    "\n",
    "# Loss curve: recorded loss value against epoch index.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(epoch_list, loss_list)\n",
    "ax.set_xlabel('epoch')\n",
    "ax.set_ylabel('loss')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "44143c5ac5f3ceb8e37c69c3af73325ae55d21292b2c7b54871fd886482dde4c"
  },
  "kernelspec": {
   "display_name": "Python 3.6.13 ('pytorch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
